import numpy,os
from matplotlib import pyplot
from scipy import constants
from jqc import jqc_plot
from diatom import Hamiltonian
from sympy.physics.wigner import wigner_3j,wigner_9j
from matplotlib.collections import LineCollection
from matplotlib.colors import LogNorm,LinearSegmentedColormap
from matplotlib.ticker import (
    AutoLocator, AutoMinorLocator)
from matplotlib import gridspec

# Molecular constants for RbCs, taken from the diatom package.
Consts = Hamiltonian.RbCs
# Apply the shared plotting style before any figure is created.
jqc_plot.plot_style()

# Absolute directory of this script, and its parent; the parent (fpath) is
# the root from which data files are located.
cwd = os.path.dirname(os.path.abspath(__file__))
fpath = os.path.dirname(cwd)

colours = jqc_plot.colours

# 2 x 6 layout. Columns 0 and 2 are wide plot columns, columns 1 and 3 are
# very narrow (0.02), and columns 4 and 5 are narrow (0.1).
grid = gridspec.GridSpec(nrows=2, ncols=6,
                         width_ratios=[1, 0.02, 1, 0.02, 0.1, 0.1])

def _with_alpha_ramp(segmentdata):
    '''
    Return a shallow copy of a LinearSegmentedColormap segment dictionary
    with an 'alpha' channel added: fully transparent at 0, half-opaque at
    0.25 and fully opaque from 0.5 up to 1.
    '''
    ramped = segmentdata.copy()
    ramped['alpha'] = ((0.0, 0.0, 0.0),
                       (0.25, .5, .5),
                       (0.5, 1., 1.),
                       (1.0, 1.0, 1.0))
    return ramped


def _register_colormap(cmap):
    '''
    Register cmap under cmap.name so it can later be referenced by string.

    pyplot.register_cmap (used previously) was deprecated and removed in
    matplotlib 3.9. Prefer the matplotlib.colormaps registry when it supports
    registration, and fall back to register_cmap on older matplotlib.
    '''
    import matplotlib
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "register"):
        # force=True replaces an existing map of the same name, as the old
        # API did, so re-running the script in one session does not fail.
        registry.register(cmap, force=True)
    else:
        pyplot.register_cmap(cmap=cmap)


# Segment data for the "blue" map: each channel ramps linearly from the RGB
# colour (244, 234, 168)/255 at 0 to (0, 70, 127)/255 at 0.6, then stays flat.
colour_dict_twk_blue = {
    "red": [(0.0, 244/255, 244/255),
            (0.6, 0, 0),
            (1.0, 0, 0)],
    "green": [(0.0, 234/255, 234/255),
              (0.6, 70/255, 70/255),
              (1.0, 70/255, 70/255)],
    "blue": [(0.0, 168/255, 168/255),
             (0.6, 127/255, 127/255),
             (1.0, 127/255, 127/255)],
}
colour_dict_twk_blue_alpha = _with_alpha_ramp(colour_dict_twk_blue)

RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                            colour_dict_twk_blue_alpha)
_register_colormap(RbCs_map_twk_blue)

# Segment data for the "red" map: same start colour, ramping to
# (198, 62, 98)/255 at 0.6, then flat.
colour_dict_twk_red = {
    "red": [(0.0, 244/255, 244/255),
            (0.6, 198/255, 198/255),
            (1.0, 198/255, 198/255)],
    "green": [(0.0, 234/255, 234/255),
              (0.6, 62/255, 62/255),
              (1.0, 62/255, 62/255)],
    "blue": [(0.0, 168/255, 168/255),
             (0.6, 98/255, 98/255),
             (1.0, 98/255, 98/255)],
}
colour_dict_twk_red_alpha = _with_alpha_ramp(colour_dict_twk_red)

RbCs_map_twk_red = LinearSegmentedColormap("RbCs_map_tweak_red",
                                           colour_dict_twk_red_alpha)
_register_colormap(RbCs_map_twk_red)



def make_segments(x, y):
    '''
    Turn a polyline into the segment array expected by LineCollection.

    Each pair of consecutive points (x[k], y[k]) -> (x[k+1], y[k+1]) becomes
    one segment. For 1-D x and y of equal length the result has shape
    (len(x) - 1, 2, 2): segment index, endpoint (start/end), coordinate (x/y).
    '''
    coords = numpy.transpose(numpy.array([x, y])).reshape(-1, 1, 2)
    starts = coords[:-1]
    ends = coords[1:]
    return numpy.concatenate((starts, ends), axis=1)

def colorline(x, y, z=None, cmap=None, norm=None, linewidth=3, alpha=1.0,
              legend=False, ax=None, zorder=1.25):
    '''
    Plot a line through (x, y) whose colour varies along its length.

    The line is split into consecutive straight segments (make_segments) and
    drawn as one LineCollection coloured by mapping z through norm and cmap.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the points along the line.
    z : array_like or scalar, optional
        Values used to colour the segments. Defaults to len(x) values evenly
        spaced on [0, 1]. A scalar (or 0-d array) is promoted to length 1.
    cmap : Colormap or str, optional
        Colormap or registered colormap name. Defaults to 'copper'.
    norm : Normalize, optional
        Maps z onto [0, 1]. Defaults to a fresh Normalize(0.0, 1.0) per call.
    linewidth : float
        Width of the line segments.
    alpha : float
        Accepted for backward compatibility but NOT applied. Forwarding it to
        LineCollection would override any alpha channel in the colormap.
    legend : bool
        Unused; kept for backward compatibility.
    ax : Axes, optional
        Target axes; defaults to the current axes.
    zorder : float
        Drawing order of the collection (default 1.25).

    Returns
    -------
    LineCollection
        The collection added to ax (usable as a colorbar mappable).
    '''
    if ax is None:
        ax = pyplot.gca()

    # Resolve defaults per call rather than at import time, so callers never
    # share (and accidentally mutate) a single Normalize/Colormap instance.
    if cmap is None:
        cmap = pyplot.get_cmap('copper')
    if norm is None:
        norm = pyplot.Normalize(0.0, 1.0)

    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))
    # Matplotlib collections only colour-map rank-1 arrays, so promote scalars
    # and 0-d arrays to 1-d.
    z = numpy.atleast_1d(numpy.asarray(z))

    lc = LineCollection(make_segments(x, y), array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth, zorder=zorder)
    ax.add_collection(lc)
    return lc

def _dipole_component(Nmax, d, q):
    '''
    Matrix of spherical component q of a dipole of magnitude d in the
    rigid-rotor basis |N, M_N>, N = 0 .. Nmax.

    Basis order: N ascending; within each N, M_N = N, N-1, ..., -N.
    Element [i, j] = d * sqrt((2N_i+1)(2N_j+1)) * (-1)**M_i
                     * (N_i 1 N_j; -M_i q M_j) * (N_i 1 N_j; 0 0 0),
    where (...) are Wigner 3j symbols.

    Parameters
    ----------
    Nmax : int
        Highest rotational quantum number included.
    d : float
        Dipole magnitude; an overall scale factor on every element.
    q : int
        Spherical component: 0, +1 or -1.

    Returns
    -------
    numpy.ndarray
        complex128 array of shape ((Nmax+1)**2, (Nmax+1)**2).
    '''
    basis = [(N, M) for N in range(Nmax + 1) for M in range(N, -(N + 1), -1)]
    size = len(basis)
    # numpy.complex128 replaces the numpy.complex alias removed in NumPy 1.24.
    Dmat = numpy.zeros((size, size), dtype=numpy.complex128)
    for i, (N1, M1) in enumerate(basis):
        for j, (N2, M2) in enumerate(basis):
            # The 3j symbols vanish unless |N1 - N2| == 1 (triangle rule plus
            # parity of (N1 1 N2; 0 0 0)) and -M1 + q + M2 == 0; those
            # entries stay zero without any slow sympy evaluation.
            if abs(N1 - N2) != 1 or M2 - M1 + q != 0:
                continue
            Dmat[i, j] = (d*numpy.sqrt((2*N1+1)*(2*N2+1))*(-1)**M1
                          * wigner_3j(N1, 1, N2, -M1, q, M2)
                          * wigner_3j(N1, 1, N2, 0, 0, 0))
    return Dmat


def dipolez(Nmax, d):
    ''' Dipole matrix, q = 0 component, in the |N, M_N> basis (see _dipole_component). '''
    return _dipole_component(Nmax, d, 0)


def dipolep(Nmax, d):
    ''' Dipole matrix, q = +1 component, in the |N, M_N> basis (see _dipole_component). '''
    return _dipole_component(Nmax, d, +1)


def dipolem(Nmax, d):
    ''' Dipole matrix, q = -1 component, in the |N, M_N> basis (see _dipole_component). '''
    return _dipole_component(Nmax, d, -1)

h = constants.h  # Planck constant (scipy), used to turn energies into frequencies

Nmax = 5   # highest rotational level N in the basis
I1 = 3/2   # nuclear spins of the two nuclei
I2 = 7/2

# Uncoupled quantum numbers [N, M_N, M_I1, M_I2] for each basis state:
# N ascending, then M_N, M_I1, M_I2 each descending. The plotting loops use
# indices[i] to label eigenstate i. This presumably relies on the stored
# eigenstates being sorted in this same order -- TODO confirm.
indices = []
for N in range(0, Nmax+1):
    for MN in range(N, -(N+1), -1):
        for MI1 in numpy.arange(I1, -(I1+1), -1):
            for MI2 in numpy.arange(I2, -(I2+1), -1):
                indices.append([N, MN, MI1, MI2])

# Nuclear-spin product state with M_I1 = I1 and M_I2 = I2 (element 0 of each
# spin space is the largest projection, as in the loops above).
I1_state = numpy.zeros(int(2*I1+1))
I1_state[0] = 1
I2_state = numpy.zeros(int(2*I2+1))
I2_state[0] = 1
I_State = numpy.kron(I1_state, I2_state)

# Angular-momentum operators in the full basis. F[2] is presumably the
# z component of the total angular momentum -- confirm against diatom.
Nvec, I1vec, I2vec = Hamiltonian.Generate_vecs(Nmax, I1, I2)
F = Nvec+I1vec+I2vec
Fz = F[2]

# Beta = 0 energies. Row 0 is the intensity grid; the remaining rows are the
# energies of each state at those intensities. os.path.join keeps the path
# portable (it was previously built with Windows-only "\\" separators).
_beta0_dir = os.path.join(fpath, "AC Stark Data", "Sorted", "1064", "Beta=0")
Hyperfine_energy = numpy.genfromtxt(os.path.join(_beta0_dir, "N5_B0_E300.csv"),
                                    delimiter=',')

# Scaled by 1e-7, presumably W m^-2 -> kW cm^-2 to match the axis label.
Intensity = 1e-7*Hyperfine_energy[0, :]
Hyperfine_energy = Hyperfine_energy[1:, :]

# Number of states with N <= Nplotmax: every |N, M_N> carries
# (2*I1+1)*(2*I2+1) nuclear-spin sublevels (32 for these spins).
Nplotmax = 1
_n_hf = int((2*I1+1)*(2*I2+1))
numberplot = sum((2*x+1)*_n_hf for x in range(Nplotmax+1))
colours_fixed = [colours['red'], colours['grayblue'], colours['green'],
                 colours['purple']]

fig = pyplot.figure("STARK")

# Top and bottom beta = 0 panels share the intensity axis.
ax_Stark0 = fig.add_subplot(grid[0, 0])
ax_Stark1 = fig.add_subplot(grid[1, 0], sharex=ax_Stark0)

k = 0

# Transition dipole moments (TDMs) from eigenstate 0 to every N = 1 eigenstate
# for the beta = 0 data: q = 0 (dpi), q = +1 (dsp) and q = -1 (dsm). Each one
# is read from a CSV cache when present, otherwise computed and cached.
_beta0_dir = os.path.join(fpath, "AC Stark Data", "Sorted", "1064", "Beta=0")

# Load the eigenvectors unconditionally. The plotting loop reads
# Hyperfine_States to find each state's m_F. Previously they were loaded only
# on a cache miss, so a fully cached run failed with a NameError.
Hyperfine_States = numpy.load(os.path.join(_beta0_dir, "N5_B0_E300_states.npy"))

# Hyperfine sublevels per |N, M_N>. N = 0 has a single M_N, so this is also
# the index of the first N = 1 state.
_n_hf = int((2*I1+1)*(2*I2+1))

_tdm = {}
for _component, _operator in (("z", dipolez), ("p", dipolep), ("m", dipolem)):
    _path = os.path.join(_beta0_dir, "Ncalc5_N1_TDM" + _component + ".csv")
    try:
        _tdm[_component] = numpy.genfromtxt(_path, delimiter=',',
                                            dtype=numpy.complex128)
    except IOError:
        print("Calculating dipole moments")
        # Extend the rotational operator to the full basis with an identity
        # on the combined nuclear-spin space.
        _op = numpy.kron(_operator(Nmax, 1), numpy.identity(_n_hf))
        # Row k: state 0 -> state (_n_hf + k), evaluated at every intensity x.
        # NOTE(review): the bra is not conjugated; this equals <0|d_q|k> only
        # if eigenvector 0 is real (up to a global phase) -- confirm.
        _tdm[_component] = numpy.einsum(
            'ix,ij,jkx->kx', Hyperfine_States[:, 0, :], _op,
            Hyperfine_States[:, _n_hf:numberplot, :])
        numpy.savetxt(_path, _tdm[_component], delimiter=',')
        print("Saved dipole moments to:" + _path)
    else:
        print("loaded TDM")

# Keep exactly the N = 1 rows. Caches written by earlier versions contain one
# extra (N = 2) row because their slice ran to numberplot + 1.
_n_excited = numberplot - _n_hf
dpi, dsp, dsm = (_tdm[c][:_n_excited] for c in ("z", "p", "m"))

d_sig = (numpy.abs(dsm)+numpy.abs(dsp))/2
d_pi = dpi

# Beta = 0 panels. Find the N = 0, m_F = 5 reference state (gs), then draw
# each N = 1 transition frequency (offset by -980 MHz) on both panels,
# coloured by that state's own pi and sigma TDMs.
_n_hf = int((2*I1+1)*(2*I2+1))  # index of the first N = 1 state
gs = None
for i, index in enumerate(indices[:numberplot]):
    if index[0] == 0:
        # <m_F> of eigenstate i, evaluated at intensity index 1.
        mf = numpy.round(numpy.dot(numpy.conjugate(Hyperfine_States[:, i, 1]),
                                   numpy.dot(Fz, Hyperfine_States[:, i, 1])).real, 1)
        if mf == 5:
            gs = i
            print(gs)
    elif index[0] == 1:
        if gs is None:
            raise RuntimeError("no N=0, m_F=5 reference state found before "
                               "the N=1 states")
        freq = 1e-6*(Hyperfine_energy[i, :]-Hyperfine_energy[gs, :])/h-980
        # TDM rows start at the first N = 1 state. Colour with THIS state's
        # row: passing the whole 2-D TDM array fails, because matplotlib
        # collections only map rank-1 arrays.
        row = i-_n_hf
        for ax in (ax_Stark0, ax_Stark1):
            ax.plot(Intensity, freq, color=colours['sand'], alpha=0.5, zorder=1.0)
            cl = colorline(Intensity, freq, numpy.abs(d_pi[row]),
                           cmap='RbCs_map_tweak_blue', norm=LogNorm(1e-2, 1.0),
                           linewidth=2.0, ax=ax)
            cl2 = colorline(Intensity, freq, numpy.abs(d_sig[row]),
                            cmap='RbCs_map_tweak_red', norm=LogNorm(1e-2, 1.0),
                            linewidth=2.0, ax=ax)

ax_Stark0.set_xlim(0, 10)


# Beta = 90 data set: energies, panels and transition dipole moments.
_beta90_dir = os.path.join(fpath, "AC Stark Data", "Sorted", "1064", "Beta=90")
Hyperfine_energy = numpy.genfromtxt(os.path.join(_beta90_dir, "N5_B90_E300.csv"),
                                    delimiter=',')

# Row 0 is the intensity grid (same 1e-7 scaling as beta = 0); rest are energies.
Intensity = 1e-7*Hyperfine_energy[0, :]
Hyperfine_energy = Hyperfine_energy[1:, :]
ax_Stark2 = fig.add_subplot(grid[0, 2], sharey=ax_Stark0)
ax_Stark3 = fig.add_subplot(grid[1, 2], sharey=ax_Stark1, sharex=ax_Stark2)
k = 0

# Load the beta = 90 eigenvectors unconditionally. Previously
# Hyperfine_States was reused from beta = 0 (or undefined) whenever the beta = 90
# TDMs were cached, so the m_F labelling below used the wrong states.
Hyperfine_States = numpy.load(os.path.join(_beta90_dir, "N5_B90_E300_states.npy"))

# Hyperfine sublevels per |N, M_N>; also the index of the first N = 1 state.
_n_hf = int((2*I1+1)*(2*I2+1))

_tdm = {}
for _component, _operator in (("z", dipolez), ("p", dipolep), ("m", dipolem)):
    _path = os.path.join(_beta90_dir, "Ncalc5_N1_TDM" + _component + ".csv")
    try:
        _tdm[_component] = numpy.genfromtxt(_path, delimiter=',',
                                            dtype=numpy.complex128)
    except IOError:
        print("Calculating dipole moments")
        _op = numpy.kron(_operator(Nmax, 1), numpy.identity(_n_hf))
        # Row k: state 0 -> state (_n_hf + k) at every intensity x.
        # NOTE(review): bra not conjugated -- see the beta = 0 computation.
        _tdm[_component] = numpy.einsum(
            'ix,ij,jkx->kx', Hyperfine_States[:, 0, :], _op,
            Hyperfine_States[:, _n_hf:numberplot, :])
        numpy.savetxt(_path, _tdm[_component], delimiter=',')
        print("Saved dipole moments to:" + _path)
    else:
        print("loaded TDM")

# Keep exactly the N = 1 rows (older caches carry one extra N = 2 row).
_n_excited = numberplot - _n_hf
dpi, dsp, dsm = (_tdm[c][:_n_excited] for c in ("z", "p", "m"))

d_sig = (numpy.abs(dsm)+numpy.abs(dsp))/2
d_pi = dpi

# Beta = 90 panels: same construction as beta = 0, using the beta = 90
# energies, eigenvectors and TDMs.
_n_hf = int((2*I1+1)*(2*I2+1))  # index of the first N = 1 state
for i, index in enumerate(indices[:numberplot]):
    if index[0] == 0:
        mf = numpy.round(numpy.dot(numpy.conjugate(Hyperfine_States[:, i, 1]),
                                   numpy.dot(Fz, Hyperfine_States[:, i, 1])).real, 1)
        if mf == 5:
            gs = i
            print(gs)
    elif index[0] == 1:
        freq = 1e-6*(Hyperfine_energy[i, :]-Hyperfine_energy[gs, :])/h-980
        # Colour by this state's own TDM row, not the whole 2-D array.
        row = i-_n_hf
        for ax in (ax_Stark2, ax_Stark3):
            ax.plot(Intensity, freq, color=colours['sand'], alpha=0.5, zorder=1.0)
            cl = colorline(Intensity, freq, numpy.abs(d_pi[row]),
                           cmap='RbCs_map_tweak_blue', norm=LogNorm(1e-2, 1.0),
                           linewidth=2.0, ax=ax)
            cl2 = colorline(Intensity, freq, numpy.abs(d_sig[row]),
                            cmap='RbCs_map_tweak_red', norm=LogNorm(1e-2, 1.0),
                            linewidth=2.0, ax=ax)

ax_Stark2.set_xlim(0, 10)
ax_Stark2.text(0.01, 1.03, "+980 MHz", fontsize=15, clip_on=False,
               transform=ax_Stark2.transAxes)
ax_Stark2.tick_params(labelleft=False)
ax_Stark3.tick_params(labelleft=False)

ax_Stark0.text(0.01, 1.03, "+980 MHz", fontsize=15, clip_on=False,
               transform=ax_Stark0.transAxes)
ax_Stark0.text(-0.23, 0, "Transition Frequency (MHz)", clip_on=False,
               transform=ax_Stark0.transAxes, rotation=90,
               verticalalignment='center', horizontalalignment='center')

# Raw strings: "\c" in a normal literal is an invalid escape sequence.
ax_Stark0.text(0.5, 0.85, r"$\beta=0^\circ$", fontsize=15, transform=ax_Stark0.transAxes)
ax_Stark2.text(0.5, 0.85, r"$\beta=90^\circ$", fontsize=15, transform=ax_Stark2.transAxes)

ax_Stark0.tick_params(labelbottom=False)
ax_Stark2.tick_params(labelbottom=False)

ax_Stark1.set_xlabel("Intensity, $I$ (kW$\\,$cm$^{-2}$)")
ax_Stark3.set_xlabel("Intensity, $I$ (kW$\\,$cm$^{-2}$)")

# Colour bar for the pi (q = 0) TDM, from the last blue colour line drawn.
colax = fig.add_subplot(grid[:, 4])
colax.set_title("$d_z$", fontsize=15)
fig.colorbar(cl, cax=colax)
colax.tick_params(labelright=False, right=False, which='both')

# Top and bottom rows show different frequency windows (y axes shared per row).
ax_Stark0.set_ylim(18, 21)
ax_Stark1.set_ylim(8, 11)

# Colour bar for the sigma TDM, from the last red colour line drawn.
colax = fig.add_subplot(grid[:, 5])
colax.set_title("$d_\\pm$", fontsize=15)
fig.colorbar(cl2, cax=colax)

pyplot.tight_layout()

pyplot.subplots_adjust(hspace=0.1, wspace=0.1, top=.93, right=0.9, left=0.14)

pyplot.savefig("fig3.pdf")
pyplot.savefig("fig3.png")
pyplot.show()
